import hashlib
import json
import os
import pathlib
import re

import numpy as np
import pandas as pd
from openai import (
    APIConnectionError,
    APIError,
    APITimeoutError,
    InternalServerError,
    OpenAI,
    RateLimitError,
)
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
from tqdm import tqdm

# ========== Parameters to be manually modified ==========
# Credentials and endpoint handed to the module-level OpenAI client.
# NOTE(review): avoid committing a real key in this file — consider reading it
# from the environment instead.
API_KEY = ""
BASE_URL = ""
# JSONL files to process; the __main__ block passes each entry to process_one().
INPUT_FILES = [
    # r"path/to/your/file1.jsonl",
    # r"path/to/your/file2.jsonl",
]

# ========== Embedding and Field Settings ==========
EMBEDDING_MODEL = "text-embedding-3-large"  # Model name sent with every embeddings request
DIMENSIONS = None          # Can be set to 1024/512/256 etc. to save space; None (or 0) omits the parameter
BATCH_SIZE = 96            # Number of unique texts sent per embeddings request
FIELD_ORIG = "min_word_prompt1"         # JSONL key holding the original prompt
FIELD_REWRITE = "seeminglytoxicprompt"  # JSONL key holding the rewritten prompt
# ===================================

# ========== Text Cleaning Options ==========
# Retention policy when removing punctuation (read by build_clean_regex and clean_text):
#   - KEEP_SPACES:     Whether to keep whitespace (\s)
#   - KEEP_CHINESE:    Whether to keep Chinese characters (common Hanzi range \u4e00-\u9fff)
#   - KEEP_UNDERSCORE: Whether to keep underscore "_" (part of \w, usually recommended)
KEEP_SPACES = True
KEEP_CHINESE = True
KEEP_UNDERSCORE = True

def build_clean_regex(keep_spaces=None, keep_chinese=None, keep_underscore=None):
    """
    Build the compiled regex that matches every character to be deleted.

    Each argument overrides the module-level flag of the same meaning; when an
    argument is None the corresponding KEEP_SPACES / KEEP_CHINESE /
    KEEP_UNDERSCORE constant is used, so calling this with no arguments keeps
    the original configuration-driven behavior.

    Args:
        keep_spaces: Keep whitespace (\\s) when true.
        keep_chinese: Keep characters in the range \\u4e00-\\u9fff when true.
        keep_underscore: Keep "_" when true.

    Returns:
        A compiled pattern; substituting its matches with "" leaves only the
        characters that should be retained.
    """
    if keep_spaces is None:
        keep_spaces = KEEP_SPACES
    if keep_chinese is None:
        keep_chinese = KEEP_CHINESE
    if keep_underscore is None:
        keep_underscore = KEEP_UNDERSCORE

    # NOTE: for str patterns in Python 3, \w is Unicode-aware. It matches
    # letters and digits of every script (including CJK ideographs) plus "_",
    # not just [A-Za-z0-9_]. Word characters are therefore always kept, and the
    # underscore / Chinese flags are honored by *adding* those characters to
    # the removal set rather than by narrowing \w to ASCII.
    kept = r"\w"
    if keep_spaces:
        kept += r"\s"
    if keep_chinese:
        kept += r"\u4e00-\u9fff"
    pattern = rf"[^{kept}]"

    # Characters that \w would otherwise protect but the flags ask to drop.
    forced_out = ""
    if not keep_underscore:
        forced_out += "_"
    if not keep_chinese:
        forced_out += r"\u4e00-\u9fff"
    if forced_out:
        pattern = rf"{pattern}|[{forced_out}]"

    return re.compile(pattern)

# Compiled once at import time; clean_text() reuses this pattern on every call.
CLEAN_RE = build_clean_regex()

def clean_text(text: str) -> str:
    """
    Strip every character matched by CLEAN_RE from ``text``.

    Falsy input yields "". Otherwise the value is converted to ``str``, line
    breaks become spaces, disallowed characters are deleted, and (when
    KEEP_SPACES is enabled) runs of whitespace collapse to a single space.
    """
    if not text:
        return ""

    # Turn line breaks into spaces before anything else is removed.
    flattened = str(text)
    for newline in ("\r\n", "\n"):
        flattened = flattened.replace(newline, " ")

    # Delete punctuation and any other disallowed characters.
    stripped = CLEAN_RE.sub("", flattened.strip())
    if not KEEP_SPACES:
        return stripped

    # Collapse repeated whitespace left behind by the removals.
    return re.sub(r"\s+", " ", stripped).strip()

# ========== OpenAI Client ==========
# An unset (empty) BASE_URL is passed as None so the SDK uses its own default
# endpoint rather than attempting requests against an empty, invalid URL.
client = OpenAI(api_key=API_KEY, base_url=BASE_URL or None)

# ========== Utility Functions ==========
def sha1(text: str) -> str:
    """Return the hexadecimal SHA-1 digest of ``text`` encoded as UTF-8."""
    hasher = hashlib.sha1()
    hasher.update(text.encode("utf-8"))
    return hasher.hexdigest()

def read_jsonl(path: str):
    """
    Read a JSON Lines file and return the parsed value of every valid line.

    Blank lines are ignored. Lines that are not valid JSON are skipped
    (best-effort), and the number skipped is printed so data loss is visible.

    Args:
        path: Path of the JSONL file to read.

    Returns:
        A list with one parsed JSON value per valid, non-blank line, in file order.
    """
    rows = []
    skipped = 0
    # "utf-8-sig" drops a leading byte-order mark. With plain "utf-8" the BOM
    # stays attached to the first line, json.loads rejects it, and the first
    # record would be lost.
    with open(path, "r", encoding="utf-8-sig") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                rows.append(json.loads(line))
            except json.JSONDecodeError:
                # Keep going on bad lines, but remember that something was dropped.
                skipped += 1
    if skipped:
        print(f"Warning: skipped {skipped} malformed JSON line(s) in {path}")
    return rows

def build_pairs(rows):
    """
    Build the list of cleaned (orig, rewrite) pairs from parsed JSONL records.

    A record contributes a pair only when it is a JSON object containing both
    FIELD_ORIG and FIELD_REWRITE and both texts are non-empty after
    clean_text(). Records of any other JSON type (numbers, strings, lists,
    null) are skipped instead of raising.

    Args:
        rows: Parsed JSON values, one per input line.

    Returns:
        A list of dicts with keys "pair_id" (zero-padded index of the record
        within ``rows``), "orig" and "rewrite" (the cleaned texts).
    """
    pairs = []
    for i, r in enumerate(rows):
        # Valid JSON is not necessarily an object; only dicts can carry the fields.
        if not isinstance(r, dict):
            continue
        if FIELD_ORIG not in r or FIELD_REWRITE not in r:
            continue
        # `or ""` maps null/empty values to an empty string before cleaning.
        o = clean_text(str(r[FIELD_ORIG] or ""))
        w = clean_text(str(r[FIELD_REWRITE] or ""))
        if o and w:
            pairs.append({"pair_id": f"{i:07d}", "orig": o, "rewrite": w})
    return pairs

@retry(
    stop=stop_after_attempt(6),
    wait=wait_exponential(multiplier=1, min=1, max=20),
    # Retry only transient failures. The base APIError also covers permanent
    # errors (bad credentials, invalid request), which would just fail again
    # after every backoff.
    retry=retry_if_exception_type(
        (RateLimitError, APIConnectionError, APITimeoutError, InternalServerError)
    ),
    reraise=True
)
def embed_batch(texts):
    """
    Request embeddings for one batch of texts.

    Args:
        texts: List of strings to embed in a single API call.

    Returns:
        A list of embedding vectors, one per input text, in input order.

    Raises:
        ValueError: If the response does not contain exactly one embedding per
            input text.
        Exception: The last API error is re-raised once the retry budget
            (6 attempts, exponential backoff capped at 20s) is exhausted;
            non-transient API errors are raised immediately.
    """
    kwargs = {"model": EMBEDDING_MODEL, "input": texts}
    if DIMENSIONS:
        kwargs["dimensions"] = int(DIMENSIONS)
    resp = client.embeddings.create(**kwargs)

    items = list(resp.data)
    # The caller zips the result with its inputs; a short response would
    # silently drop texts, so fail loudly here instead.
    if len(items) != len(texts):
        raise ValueError(
            f"Embedding response has {len(items)} items for {len(texts)} input texts"
        )
    # Order by the reported index when every item provides one; otherwise
    # trust the response order as-is.
    if all(isinstance(getattr(d, "index", None), int) for d in items):
        items.sort(key=lambda d: d.index)
    return [d.embedding for d in items]

def embed_unique(texts):
    """
    Embed each distinct text once and return a { sha1(text) : embedding } map.

    Duplicates are detected by SHA-1 hash; the first occurrence of each text
    fixes its position, and the unique texts are sent to embed_batch in
    slices of BATCH_SIZE.
    """
    # Dicts preserve insertion order, so this doubles as the stable dedup order.
    text_by_hash = {}
    for text in texts:
        text_by_hash.setdefault(sha1(text), text)
    ordered_hashes = list(text_by_hash)

    vecs = {}
    starts = range(0, len(ordered_hashes), BATCH_SIZE)
    for start in tqdm(starts, desc="Embedding"):
        batch = ordered_hashes[start:start + BATCH_SIZE]
        embeddings = embed_batch([text_by_hash[h] for h in batch])
        vecs.update(zip(batch, embeddings))
    return vecs

def save_outputs(jsonl_path, pairs, orig_vecs, rewrite_vecs):
    """
    Write the pair table and the embedding matrices next to the input file.

    Produces ``<stem>.pairs.csv`` (pair_id, text lengths, cleaned texts) and
    ``<stem>.embeddings.npz`` (arrays ``orig_vectors``, ``rewrite_vectors``,
    ``pair_ids``) in the same directory as ``jsonl_path``.

    Args:
        jsonl_path: Path of the source JSONL file; only used to derive output names.
        pairs: List of dicts with keys "pair_id", "orig" and "rewrite".
        orig_vecs: Embeddings of the original texts, one row per pair.
        rewrite_vecs: Embeddings of the rewritten texts, one row per pair.
    """
    p = pathlib.Path(jsonl_path)
    stem = p.stem
    out_csv = str(p.with_name(f"{stem}.pairs.csv"))
    out_npz = str(p.with_name(f"{stem}.embeddings.npz"))

    df = pd.DataFrame({
        "pair_id":     [x["pair_id"] for x in pairs],
        "orig_len":    [len(x["orig"]) for x in pairs],
        "rewrite_len": [len(x["rewrite"]) for x in pairs],
        "orig":        [x["orig"] for x in pairs],
        "rewrite":     [x["rewrite"] for x in pairs],
    })
    df.to_csv(out_csv, index=False, encoding="utf-8")

    # Store ids as a fixed-width unicode array. An object array would be
    # pickled inside the archive and could not be read back with np.load's
    # default allow_pickle=False.
    pair_ids = np.array([x["pair_id"] for x in pairs], dtype=str)
    np.savez_compressed(
        out_npz,
        orig_vectors=np.asarray(orig_vecs, dtype=np.float32),
        rewrite_vectors=np.asarray(rewrite_vecs, dtype=np.float32),
        pair_ids=pair_ids
    )

    print(f"Saved: {out_csv}")
    print(f"Saved: {out_npz}\n")

def process_one(path):
    """
    Run the whole pipeline for a single JSONL file.

    Reads the file, builds the cleaned pairs, embeds every distinct text once,
    assembles one float32 matrix per side (row i belongs to pair i) and writes
    the outputs via save_outputs. Returns early, with a message, when the file
    yields no usable pairs.
    """
    print(f"\n=== Processing {path} ===")
    pairs = build_pairs(read_jsonl(path))
    if not pairs:
        print("No valid (orig, rewrite) pairs found (may be empty after cleaning). Please check fields or data.")
        return

    # Originals first, then rewrites, so deduplication spans both sides.
    originals = [item["orig"] for item in pairs]
    rewrites = [item["rewrite"] for item in pairs]
    vec_map = embed_unique(originals + rewrites)

    # Vector width is taken from the first embedding in the map.
    dim = len(next(iter(vec_map.values())))
    shape = (len(pairs), dim)
    orig_mat = np.zeros(shape, dtype=np.float32)
    rew_mat = np.zeros(shape, dtype=np.float32)

    for row, (orig_text, rewrite_text) in enumerate(zip(originals, rewrites)):
        orig_mat[row] = vec_map[sha1(orig_text)]
        rew_mat[row] = vec_map[sha1(rewrite_text)]

    save_outputs(path, pairs, orig_mat, rew_mat)
    print(f"Total {len(pairs)} pairs, vector dimension {dim}")

if __name__ == "__main__":
    # Handle each configured input file in turn, then report completion.
    for input_path in INPUT_FILES:
        process_one(input_path)
    print("\nAll completed.")